Gamma regression for blood clotting

As a first example of this package’s functionality, we perform a Gamma regression, which is appropriate when the response variable is continuous and positive. The example is adapted from a canonical Gamma regression in McCullagh & Nelder (1989).

Normal plasma was diluted to nine different percentage concentrations (\(u\)) with prothrombin-free plasma, and clotting was induced with two lots of thromboplastin. Previous researchers fitted a hyperbolic model using an inverse transformation of the data for both lots \(1\) and \(2\); here we analyze both lots using a Gamma family with the inverse link.

The following initial plots suggest using a log scale for \(u\) to achieve inverse linearity, and show that the two lots have different slope and intercept coefficients.
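The inverse-linearity claim can also be checked numerically before plotting. The following is a minimal sketch, assuming only NumPy and the data listed in this example; the helper `pearson_r` is a hypothetical name introduced for illustration. It compares how strongly \(1/y\) correlates with \(u\) versus \(\log u\) for each lot.

```python
import numpy as np

# Data from McCullagh & Nelder (1989), as listed in this example.
u = np.array([5, 10, 15, 20, 30, 40, 60, 80, 100], dtype=float)
lots = {
    "lot1": np.array([118, 58, 42, 35, 27, 25, 21, 19, 18], dtype=float),
    "lot2": np.array([69, 35, 26, 21, 18, 16, 13, 12, 12], dtype=float),
}


def pearson_r(a, b):
    """Pearson correlation coefficient between two 1-D arrays (hypothetical helper)."""
    return float(np.corrcoef(a, b)[0, 1])


# Correlation of the inverse response with u and with log(u), per lot.
corr = {}
for name, y in lots.items():
    corr[name] = {
        "raw_u": pearson_r(u, 1.0 / y),
        "log_u": pearson_r(np.log(u), 1.0 / y),
    }
    print(name, corr[name])
```

A correlation closer to 1 on the log scale is consistent with modelling \(1/\mu\) as linear in \(\log u\), although it is only an informal check and not a substitute for fitting the model.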

[1]:
from scikit_stan import GLM

import numpy as np
import pandas as pd

import matplotlib as mpl
import matplotlib.pyplot as plt
[3]:
# ATTRIBUTION: McCullagh & Nelder (1989), chapter 8.4.2 p 301-302
bcdata_dict = {
    "u": np.array([5, 10, 15, 20, 30, 40, 60, 80, 100]),
    "lot1": np.array([118, 58, 42, 35, 27, 25, 21, 19, 18]),
    "lot2": np.array([69, 35, 26, 21, 18, 16, 13, 12, 12]),
}
bc_data_X = np.log(bcdata_dict["u"])
bc_data_lot1 = bcdata_dict["lot1"]
bc_data_lot2 = bcdata_dict["lot2"]
[4]:
l1, = plt.plot(bcdata_dict["u"], bcdata_dict["lot1"], "o", label="lot 1")
l2, = plt.plot(bcdata_dict["u"], bcdata_dict["lot2"], "o", label="lot 2")

plt.suptitle("Mean Clotting Times vs Plasma Concentration")
plt.xlabel('Normal Plasma Concentration')
plt.ylabel('Blood Clotting Time')

plt.legend(handles=[l1, l2])
[4]:
<matplotlib.legend.Legend at 0x7f6b4a814610>
../_images/examples_Gamma_Bloodclotting_5_2.svg
[5]:
l1, = plt.plot(bc_data_X, bc_data_lot1, "o", label="lot 1")
l2, = plt.plot(bc_data_X, bc_data_lot2, "o", label="lot 2")

plt.suptitle("Mean Clotting Times vs Log Plasma Concentration")
plt.xlabel('log(Normal Plasma Concentration)')
plt.ylabel('Blood Clotting Time')

plt.legend(handles=[l1, l2])
[5]:
<matplotlib.legend.Legend at 0x7f6b42727fa0>
../_images/examples_Gamma_Bloodclotting_6_1.svg

After this preliminary analysis, we fit one line per lot on the inverse scale: with \(x = \log u\), we fit a Gamma GLM to each lot.

The original results were as follows; the fit in this example recovers regression coefficients within one posterior standard deviation of these values:

\[\text{lot 1:} \quad \hat{\mu} ^{-1} = - 0.01655(\pm 0.00086) + 0.01534(\pm 0.00143)x\]
\[\text{lot 2:} \quad \hat{\mu} ^{-1} = - 0.02391(\pm 0.00038) + 0.02360(\pm 0.00062)x\]
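For comparison, the point estimates quoted above can be reproduced with a classical maximum-likelihood fit. The sketch below uses iteratively reweighted least squares (IRLS) for a Gamma GLM with the inverse link, written with NumPy only. It is not the Bayesian procedure used by this package, and `fit_gamma_inverse_irls` is a hypothetical helper introduced for illustration.

```python
import numpy as np


def fit_gamma_inverse_irls(x, y, n_iter=50, tol=1e-12):
    """Maximum-likelihood Gamma GLM with inverse link, fitted by IRLS.

    Hypothetical helper for illustration; not part of scikit-stan.
    """
    X = np.column_stack([np.ones_like(x), x])  # intercept and slope columns
    eta = 1.0 / y                              # start from mu = y
    coef = np.zeros(X.shape[1])
    for _ in range(n_iter):
        mu = 1.0 / eta
        # Inverse link: d(eta)/d(mu) = -1/mu^2 and Var(y) is proportional
        # to mu^2, so the IRLS working weights are mu^2.
        weights = mu**2
        z = eta - (y - mu) / mu**2             # working response
        XtW = X.T * weights
        new_coef = np.linalg.solve(XtW @ X, XtW @ z)
        converged = np.max(np.abs(new_coef - coef)) < tol
        coef = new_coef
        eta = X @ coef
        if converged:
            break
    return coef


u = np.array([5, 10, 15, 20, 30, 40, 60, 80, 100], dtype=float)
lot1 = np.array([118, 58, 42, 35, 27, 25, 21, 19, 18], dtype=float)
lot2 = np.array([69, 35, 26, 21, 18, 16, 13, 12, 12], dtype=float)
x = np.log(u)

coef_lot1 = fit_gamma_inverse_irls(x, lot1)
coef_lot2 = fit_gamma_inverse_irls(x, lot2)
print("lot 1 (intercept, slope):", coef_lot1)
print("lot 2 (intercept, slope):", coef_lot2)
```

This sketch estimates only the coefficients; it does not compute the standard errors shown in parentheses above.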

As in previous work, we fit a separate model to each lot in the dataset. As usual, the \(\alpha\) parameter is the regression intercept, \(\mathbf{\beta}\) is the vector of regression coefficients, and \(\sigma\) is an auxiliary parameter of the model. In this case, \(\sigma\) is the shape parameter of the Gamma distribution.
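To make the roles of the parameters concrete, the following sketch simulates from a Gamma model with the inverse link using NumPy. It assumes a shape-and-mean parameterization (shape \(\sigma\), scale \(\mu / \sigma\)), which is an assumption for illustration and may differ from how the package's Stan model is written internally; the value \(\sigma = 4\) is likewise illustrative. The coefficients are the lot 1 values quoted above.

```python
import numpy as np

rng = np.random.default_rng(1234)

alpha, beta = -0.01655, 0.01534  # lot 1 coefficients quoted in this example
sigma = 4.0                      # illustrative shape value (assumption)

x = np.log(np.array([5.0, 20.0, 100.0]))
mu = 1.0 / (alpha + beta * x)    # inverse link: 1/mu = alpha + beta * x

# Assumed parameterization: shape = sigma, scale = mu / sigma, so E[y] = mu
# and the coefficient of variation is 1 / sqrt(sigma).
draws = rng.gamma(shape=sigma, scale=mu / sigma, size=(200_000, mu.size))
sim_mean = draws.mean(axis=0)
sim_cv = draws.std(axis=0) / sim_mean

print("model mean:    ", mu)
print("simulated mean:", sim_mean)
print("simulated CV:  ", sim_cv)
```

Under this parameterization the shape controls the relative spread of the response around its mean, while \(\alpha\) and \(\mathbf{\beta}\) determine the mean through the link.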

[6]:
# Initialize two different GLM objects, one for each lot.
glm_gamma1 = GLM(family="gamma", link="inverse", seed=1234)
glm_gamma2 = GLM(family="gamma", link="inverse", seed=1234)

# Fit the models. Default priors are used without autoscaling; see the
# API documentation for how to change them.
glm_gamma1.fit(bc_data_X, bc_data_lot1, show_console=False)
glm_gamma2.fit(bc_data_X, bc_data_lot2, show_console=False)

print(glm_gamma1.alpha_, glm_gamma1.beta_)
print(glm_gamma2.alpha_, glm_gamma2.beta_)
/home/docs/checkouts/readthedocs.org/user_builds/scikit-stan/checkouts/latest/scikit_stan/utils/validation.py:263: UserWarning: Passed data is one-dimensional, while estimator expects it to be at at least two-dimensional.
  warnings.warn(
/home/docs/checkouts/readthedocs.org/user_builds/scikit-stan/checkouts/latest/scikit_stan/generalized_linear_regression/glm.py:486: UserWarning: Prior on intercept not specified. Using default prior.
                alpha ~ normal(mu(y), 2.5 * sd(y)) if Gaussian family else normal(0, 2.5)
  warnings.warn(
/home/docs/checkouts/readthedocs.org/user_builds/scikit-stan/checkouts/latest/scikit_stan/generalized_linear_regression/glm.py:532: UserWarning: Prior on auxiliary parameter not specified. Using default unscaled prior
                        sigma ~ exponential(1)

  warnings.warn(
14:09:41 - cmdstanpy - INFO - CmdStan start processing

14:09:41 - cmdstanpy - INFO - CmdStan done processing.
14:09:41 - cmdstanpy - WARNING - Some chains may have failed to converge.
        Chain 1 had 20 divergent transitions (2.0%)
        Chain 2 had 39 divergent transitions (3.9%)
        Chain 3 had 26 divergent transitions (2.6%)
        Chain 4 had 34 divergent transitions (3.4%)
        Use function "diagnose()" to see further information.
14:09:41 - cmdstanpy - INFO - CmdStan start processing


14:09:42 - cmdstanpy - INFO - CmdStan done processing.
14:09:42 - cmdstanpy - WARNING - Some chains may have failed to converge.
        Chain 1 had 8 divergent transitions (0.8%)
        Chain 2 had 37 divergent transitions (3.7%)
        Chain 3 had 17 divergent transitions (1.7%)
        Chain 4 had 8 divergent transitions (0.8%)
        Use function "diagnose()" to see further information.

[-0.01423061] [0.01500947]
[-0.01988932] [0.02298595]

As shown above, the posterior means give the following fitted models; each coefficient lies within one posterior standard deviation of the results from past studies.

\[\text{lot 1:} \quad \hat{\mu} ^{-1} = - 0.01423 + 0.01501 \cdot x\]
\[\text{lot 2:} \quad \hat{\mu} ^{-1} = - 0.01989 + 0.02299 \cdot x\]
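The "within one standard deviation" statement can be checked with simple arithmetic. The sketch below is a hedged illustration that uses only numbers appearing in this example: the reference estimates, the posterior means printed by the fit, and the posterior standard deviations from the `StdDev` column of the CmdStanPy summary tables.

```python
import numpy as np

# Reference estimates (McCullagh & Nelder) as quoted in this example.
reference = {"lot1": (-0.01655, 0.01534), "lot2": (-0.02391, 0.02360)}
# Posterior means printed by the fit in this example.
posterior_mean = {
    "lot1": (-0.01423061, 0.01500947),
    "lot2": (-0.01988932, 0.02298595),
}
# Posterior standard deviations (StdDev column of the summary tables).
posterior_sd = {"lot1": (0.010197, 0.004417), "lot2": (0.01682, 0.00703)}

# Absolute difference from the reference, in units of posterior SD.
z_scores = {}
for lot in reference:
    diff = np.array(posterior_mean[lot]) - np.array(reference[lot])
    z_scores[lot] = np.abs(diff) / np.array(posterior_sd[lot])
    print(lot, "|difference| / posterior SD (alpha, beta):", z_scores[lot])
```

Values below 1 indicate that the reference estimate lies within one posterior standard deviation of the posterior mean.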

As a verification of the accuracy of the fitted model, we can plot the fitted lines and the data.
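A numerical counterpart to the visual check is to invert the link at the posterior means and compare the fitted mean clotting times with the observations. This is a minimal sketch assuming NumPy only; the coefficients are copied from the output printed by the fit in this example rather than read from the fitted objects.

```python
import numpy as np

u = np.array([5, 10, 15, 20, 30, 40, 60, 80, 100], dtype=float)
observed = {
    "lot1": np.array([118, 58, 42, 35, 27, 25, 21, 19, 18], dtype=float),
    "lot2": np.array([69, 35, 26, 21, 18, 16, 13, 12, 12], dtype=float),
}
# Posterior means (alpha, beta) as printed by the fit in this example.
posterior_mean = {
    "lot1": (-0.01423061, 0.01500947),
    "lot2": (-0.01988932, 0.02298595),
}

x = np.log(u)
fitted = {}
rel_err = {}
for lot, (alpha, beta) in posterior_mean.items():
    # Inverse link: the linear predictor is 1/mu, so mu is its reciprocal.
    fitted[lot] = 1.0 / (alpha + beta * x)
    rel_err[lot] = np.abs(fitted[lot] - observed[lot]) / observed[lot]
    print(lot, "fitted mean:", np.round(fitted[lot], 2))
    print(lot, "max relative error:", rel_err[lot].max())
```

Plugging posterior means into the inverse link gives a convenient point summary, but it is not the same as the posterior mean of \(\mu\) itself, so small discrepancies with other summaries are expected.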

[7]:
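# Under the inverse link the linear predictor equals 1/mu, so its reciprocal
# is the fitted mean (despite the `mu_inv` variable names).
mu_inv1 = 1 / (glm_gamma1.alpha_ + glm_gamma1.beta_ * bc_data_X)
mu_inv2 = 1 / (glm_gamma2.alpha_ + glm_gamma2.beta_ * bc_data_X)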
[8]:
mlot1, = plt.plot(bc_data_X, mu_inv1, "r", label="mu_inv lot 1")
mlot2, = plt.plot(bc_data_X, mu_inv2, "b", label="mu_inv lot 2")
l1, = plt.plot(bc_data_X, bc_data_lot1, "o", label="lot1")
l2, = plt.plot(bc_data_X, bc_data_lot2, "o", label="lot2")

plt.suptitle("Mean Clotting Times vs Log Plasma Concentration")
plt.xlabel('log(Normal Plasma Concentration)')
plt.ylabel('Blood Clotting Time')

plt.legend(handles=[mlot1, mlot2, l1, l2])
[8]:
<matplotlib.legend.Legend at 0x7f6b40d6cfd0>
../_images/examples_Gamma_Bloodclotting_11_1.svg

Because this package wraps CmdStanPy, additional statistics about the fitted model are available through that package’s methods. In particular, calling CmdStanPy’s summary method on the fit results reports posterior means, standard deviations, quantiles, and convergence diagnostics.

Notice that \(\mu\) (“mu”) and the link-inverted \(\mu\) (“mu unlinked”) are included as part of the model summary.

[9]:
glm_gamma1.fitted_samples_.summary()
[9]:
Parameter        Mean      MCSE    StdDev          5%         50%         95%      N_Eff     N_Eff/s    R_hat
lp__       -37.778600  0.046785  1.403670  -40.522200  -37.400600  -36.293900   900.1620  2719.52000  1.00412
alpha[1]    -0.014231  0.000337  0.010197   -0.029048   -0.015165    0.003691   914.9120  2764.08000  1.00151
beta[1]      0.015010  0.000148  0.004417    0.007862    0.015049    0.022065   889.3810  2686.95000  1.00240
sigma        4.691710  0.053930  2.090850    1.836280    4.359500    8.560740  1502.9555  4540.65104  0.99991
[10]:
glm_gamma2.fitted_samples_.summary()
[10]:
Parameter        Mean      MCSE   StdDev          5%         50%         95%      N_Eff     N_Eff/s    R_hat
lp__       -33.593000  0.052207  1.44186  -36.442000  -33.230700  -32.049700   762.7800  2067.16000  1.00250
alpha[1]    -0.019889  0.000554  0.01682   -0.045271   -0.020698    0.009217   921.2220  2496.54000  1.00569
beta[1]      0.022986  0.000231  0.00703    0.011729    0.022965    0.034812   923.2150  2501.94000  1.00428
sigma        4.656330  0.059330  2.08698    1.825410    4.360940    8.579550  1237.3025  3353.12332  1.00036

ArviZ, which integrates with CmdStanPy, provides further diagnostics and visualizations of the fitted model, as the following cells show.

[11]:
import arviz as az
az.style.use("arviz-darkgrid")
[12]:
infdata = az.from_cmdstanpy(glm_gamma1.fitted_samples_)
infdata
[12]:
arviz.InferenceData
    • posterior
      <xarray.Dataset>
      Dimensions:      (chain: 4, draw: 1000, alpha_dim_0: 1, beta_dim_0: 1)
      Coordinates:
        * chain        (chain) int64 0 1 2 3
        * draw         (draw) int64 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999
        * alpha_dim_0  (alpha_dim_0) int64 0
        * beta_dim_0   (beta_dim_0) int64 0
      Data variables:
          alpha        (chain, draw, alpha_dim_0) float64 -0.01532 ... -0.009792
          beta         (chain, draw, beta_dim_0) float64 0.01482 0.01114 ... 0.01252
          sigma        (chain, draw) float64 5.693 4.149 6.202 ... 5.256 3.381 8.383
      Attributes:
          created_at:                 2022-10-05T14:09:43.331096
          arviz_version:              0.12.1
          inference_library:          cmdstanpy
          inference_library_version:  1.0.7

    • sample_stats
      <xarray.Dataset>
      Dimensions:          (chain: 4, draw: 1000)
      Coordinates:
        * chain            (chain) int64 0 1 2 3
        * draw             (draw) int64 0 1 2 3 4 5 6 ... 993 994 995 996 997 998 999
      Data variables:
          lp               (chain, draw) float64 -36.11 -36.88 ... -36.73 -37.15
          acceptance_rate  (chain, draw) float64 0.9991 0.9936 ... 0.9803 0.9989
          step_size        (chain, draw) float64 0.1811 0.1811 ... 0.2019 0.2019
          tree_depth       (chain, draw) int64 3 3 4 2 3 3 4 3 3 ... 3 4 4 4 3 2 3 3 4
          n_steps          (chain, draw) int64 15 11 15 5 7 7 15 ... 15 15 7 15 15 15
          diverging        (chain, draw) bool False False False ... False False False
          energy           (chain, draw) float64 36.58 37.17 38.51 ... 37.23 37.33
      Attributes:
          created_at:                 2022-10-05T14:09:43.336452
          arviz_version:              0.12.1
          inference_library:          cmdstanpy
          inference_library_version:  1.0.7

[13]:
az.plot_posterior(infdata, var_names=['alpha', 'beta', 'sigma']);
../_images/examples_Gamma_Bloodclotting_18_0.svg
[14]:
az.plot_trace(infdata, var_names=['alpha', 'beta', 'sigma'], compact=True);
../_images/examples_Gamma_Bloodclotting_19_0.svg